package biz

import (
	"context"
	"strconv"

	"gitee.com/johnnymol/jwtsession"
	"github.com/go-kratos/kratos/v2/errors"
	"github.com/go-kratos/kratos/v2/log"
	"github.com/go-kratos/kratos/v2/transport"
	"github.com/go-kratos/kratos/v2/transport/http"
)

// AppJwtSession bundles the session store and the logger used by the session
// helper methods defined on it.
type AppJwtSession struct {
	// store is the Redis-backed store handed to jwtsession.AutoSession.
	store *jwtsession.RedisStore
	// log reports session save failures.
	log   *log.Helper
}

// Session naming constants used by the AppJwtSession helpers.
const (
	// AppJwtSessionName is the session name passed to jwtsession.AutoSession.
	AppJwtSessionName = "Ztime-Auth"

	// Jwt_Session_Field_Uid is the session field under which the uid is stored.
	Jwt_Session_Field_Uid = "uid"
)

// NewAppJwtSession builds an AppJwtSession around the given Redis session
// store, wrapping logger in a log.Helper for the session methods to use.
func NewAppJwtSession(store *jwtsession.RedisStore, logger log.Logger) *AppJwtSession {
	helper := log.NewHelper(logger)
	sess := &AppJwtSession{
		store: store,
		log:   helper,
	}
	return sess
}

// transportFromContext extracts the Kratos HTTP transport from a server
// context.
//
// It returns an internal-server error when ctx carries no server transport or
// when that transport is not a *http.Transport.
func (s *AppJwtSession) transportFromContext(ctx context.Context) (*http.Transport, error) {
	tr, found := transport.FromServerContext(ctx)
	if !found {
		return nil, errors.InternalServer("INTERNAL_ERROR", "get http Transport from session error")
	}

	httpTr, isHTTP := tr.(*http.Transport)
	if !isHTTP {
		return nil, errors.InternalServer("INTERNAL_ERROR", "get http Transport from session error")
	}
	return httpTr, nil
}

// KratosSession resolves the session for the HTTP request carried by the
// Kratos server context ctx.
//
// The session is obtained from jwtsession.AutoSession under the name
// AppJwtSessionName, configured with a HeaderInspector. An error is returned
// only when no HTTP transport can be extracted from ctx.
func (s *AppJwtSession) KratosSession(ctx context.Context) (*jwtsession.Session, error) {
	tr, err := s.transportFromContext(ctx)
	if err != nil {
		return nil, err
	}

	req := tr.Request()
	inspector := jwtsession.WithInspector(jwtsession.HeaderInspector{})
	return jwtsession.AutoSession(ctx, s.store, req, AppJwtSessionName, inspector), nil
}

// GetSession resolves the session for a plain HTTP request.
//
// The session is obtained from jwtsession.AutoSession using the request's own
// context, the name AppJwtSessionName and a HeaderInspector. The returned
// error is always nil.
func (s *AppJwtSession) GetSession(r *http.Request) (*jwtsession.Session, error) {
	reqCtx := r.Context()
	inspector := jwtsession.WithInspector(jwtsession.HeaderInspector{})
	return jwtsession.AutoSession(reqCtx, s.store, r, AppJwtSessionName, inspector), nil
}

// KratosWritterFunc returns a jwtsession.WritterFunc that stores each
// name/value pair in the reply header of the given Kratos HTTP transport.
// The returned function ignores its context argument and never fails.
func (s *AppJwtSession) KratosWritterFunc(transport *http.Transport) jwtsession.WritterFunc {
	var write jwtsession.WritterFunc = func(_ context.Context, key, val string) error {
		transport.ReplyHeader().Set(key, val)
		return nil
	}
	return write
}

// HttpWritterFunc returns a jwtsession.WritterFunc that stores each
// name/value pair in the header of the response writer w. The returned
// function ignores its context argument and never fails.
func (s *AppJwtSession) HttpWritterFunc(w http.ResponseWriter) jwtsession.WritterFunc {
	var write jwtsession.WritterFunc = func(_ context.Context, key, val string) error {
		w.Header().Set(key, val)
		return nil
	}
	return write
}

// SaveSessions is the Kratos variant of session saving: it calls
// jwtsession.Save for the request carried by the HTTP transport in ctx,
// writing through the transport's reply header.
//
// It fails when no HTTP transport can be extracted from ctx; a save failure
// is logged and returned.
func (s *AppJwtSession) SaveSessions(ctx context.Context) error {
	tr, err := s.transportFromContext(ctx)
	if err != nil {
		return err
	}

	writeHeader := s.KratosWritterFunc(tr)
	if err = jwtsession.Save(ctx, tr.Request(), writeHeader); err != nil {
		s.log.Errorf("保存session到 redisstore 失败了:err: %+v\n", err)
		return err
	}
	return nil
}

// SaveKratosSession saves a single session, writing through the reply header
// of the Kratos HTTP transport carried by ctx.
//
// It fails when no HTTP transport can be extracted from ctx; a save failure
// is logged and returned.
func (s *AppJwtSession) SaveKratosSession(ctx context.Context, session *jwtsession.Session) error {
	tr, err := s.transportFromContext(ctx)
	if err != nil {
		return err
	}

	writeHeader := s.KratosWritterFunc(tr)
	if err = session.Save(ctx, writeHeader); err != nil {
		s.log.Errorf("保存session到 redisstore 失败了:err: %+v\n", err)
		return err
	}
	return nil
}

// SaveSession saves a single session for a plain HTTP handler, writing
// through the header of the response writer w. A save failure is logged and
// returned.
func (s *AppJwtSession) SaveSession(ctx context.Context, session *jwtsession.Session, w http.ResponseWriter) error {
	if err := session.Save(ctx, s.HttpWritterFunc(w)); err != nil {
		s.log.Errorf("保存session到 redisstore 失败了:err: %+v\n", err)
		return err
	}
	return nil
}

// GetKratosSessionVal reads the value stored under name in the session of
// the Kratos request carried by ctx.
//
// When the session cannot be resolved, the underlying cause is logged and an
// Unauthorized error is returned. An error from the session lookup itself is
// returned unchanged, with an empty value.
func (s *AppJwtSession) GetKratosSessionVal(ctx context.Context, name string) (string, error) {
	session, err := s.KratosSession(ctx)
	if err != nil {
		// The caller only sees a generic Unauthorized error, so record the
		// real cause here; otherwise it is lost.
		s.log.Errorf("resolve kratos session failed: %+v", err)
		return "", errors.Unauthorized("UNAUTHORIZED", "初始化session失败了")
	}

	val, err := session.Get(name)
	if err != nil {
		return "", err
	}
	return val, nil
}

// SetKratosSessionVal stores val under name in the session of the Kratos
// request carried by ctx, then saves that session via SaveKratosSession.
//
// When the session cannot be resolved, the underlying cause is logged and an
// Unauthorized error is returned. Otherwise the result of SaveKratosSession
// is returned.
func (s *AppJwtSession) SetKratosSessionVal(ctx context.Context, name string, val string) error {
	session, err := s.KratosSession(ctx)
	if err != nil {
		// The caller only sees a generic Unauthorized error, so record the
		// real cause here; otherwise it is lost.
		s.log.Errorf("resolve kratos session failed: %+v", err)
		return errors.Unauthorized("UNAUTHORIZED", "初始化session失败了")
	}
	session.Add(name, val)
	return s.SaveKratosSession(ctx, session)
}

// GetKratosSessionUid returns the uid stored in the session of the Kratos
// request carried by ctx.
//
// An empty stored value yields (0, nil). Errors from GetKratosSessionVal and
// from parsing the stored value are returned as-is; the returned uid is
// always 0 when the error is non-nil.
func (s *AppJwtSession) GetKratosSessionUid(ctx context.Context) (int64, error) {
	val, err := s.GetKratosSessionVal(ctx, Jwt_Session_Field_Uid)
	if err != nil {
		return 0, err
	}

	if val == "" {
		return 0, nil
	}

	uid, err := strconv.ParseInt(val, 10, 64)
	if err != nil {
		// On ErrRange, ParseInt returns the saturated min/max int64 alongside
		// the error; never hand that out as a uid.
		return 0, err
	}
	return uid, nil
}

// SetKratosSessionUid stores uid, formatted as a base-10 string, under
// Jwt_Session_Field_Uid in the session of the Kratos request carried by ctx.
// It returns whatever SetKratosSessionVal returns.
func (s *AppJwtSession) SetKratosSessionUid(ctx context.Context, uid int64) error {
	uidText := strconv.FormatInt(uid, 10)
	return s.SetKratosSessionVal(ctx, Jwt_Session_Field_Uid, uidText)
}

// GetSessionUid returns the uid stored in the session of a plain HTTP
// request.
//
// Errors from resolving the session, reading the uid field, or parsing it are
// returned as-is, and the returned uid is always 0 when the error is non-nil.
// An empty stored value is not special-cased here: it is passed to
// strconv.ParseInt and reported as a parse error.
func (s *AppJwtSession) GetSessionUid(r *http.Request) (int64, error) {
	session, err := s.GetSession(r)
	if err != nil {
		return 0, err
	}

	val, err := session.Get(Jwt_Session_Field_Uid)
	if err != nil {
		return 0, err
	}

	uid, err := strconv.ParseInt(val, 10, 64)
	if err != nil {
		// On ErrRange, ParseInt returns the saturated min/max int64 alongside
		// the error; never hand that out as a uid.
		return 0, err
	}
	return uid, nil
}
